[MXNET-1359] Adds a multiclass-MCC metric derived from Pearson #14461
Conversation
Hmm, actually... And of course this PR should have some tests added. I'll work through both these issues.
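For reference, a test along the following lines could exercise the new metric. This is only a sketch: it assumes the metric lands as mxnet.metric.PCC (the name proposed in this PR) alongside the existing mxnet.metric.MCC, using the usual EvalMetric.update(labels, preds) interface; the test function name is hypothetical.

import mxnet as mx
import numpy as np

def test_pcc_matches_mcc_on_binary():
    # On a two-class problem the multiclass coefficient should reduce to MCC,
    # which is the premise of this PR.
    labels = [mx.nd.array([0, 1, 1, 0, 1, 0])]
    preds = [mx.nd.array([[0.7, 0.3],
                          [0.2, 0.8],
                          [0.4, 0.6],
                          [0.9, 0.1],
                          [0.3, 0.7],
                          [0.6, 0.4]])]
    mcc = mx.metric.MCC()
    pcc = mx.metric.PCC()  # metric added by this PR
    mcc.update(labels, preds)
    pcc.update(labels, preds)
    np.testing.assert_almost_equal(pcc.get()[1], mcc.get()[1])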
Force-pushed from 39580ff to 6f0a2f4.
Thank you for your contribution @tlby! @mxnet-label-bot add[pr-work-in-progress]
Force-pushed from 3d2c5de to 64dde20.
diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py
index 2a33cf4d9..6de76cc64 100644
--- a/python/mxnet/metric.py
+++ b/python/mxnet/metric.py
@@ -1576,9 +1576,8 @@ class PCC(EvalMetric):
n = max(pred.max(), label.max())
if n >= self.k:
self._grow(n + 1 - self.k)
- bcm = numpy.zeros((self.k, self.k))
- for i, j in zip(pred, label):
- bcm[i, j] += 1
+ ident = numpy.identity(self.k)
+ bcm = numpy.tensordot(ident[label], ident[pred].T, axes=(0,1))
self.lcm += bcm
self.gcm += bcm
This seems more efficient for constructing the confusion matrix, but benchmarks worse. I'm new to NumPy though; anyone see a better approach?
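One more option worth trying (just a sketch, not benchmarked here; build_confusion is a hypothetical helper name) is numpy.add.at, which does the scatter-add without a Python-level loop and without materializing one-hot matrices:

import numpy

def build_confusion(pred, label, k):
    # Accumulate a k x k confusion matrix indexed as [pred, label],
    # matching the existing loop's indexing.
    bcm = numpy.zeros((k, k))
    numpy.add.at(bcm, (pred, label), 1)
    return bcm

Whether this beats the loop or bincount will depend on batch size and k; numpy.add.at trades the Python loop for an unbuffered scatter, which is not always faster.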
I think this PR is ready now, sorry for posting the PR prematurely. I am happy with test coverage at this point, and happy with the metric scaling out to a large number of classes. The …
@mxnet-label-bot update[pr-awaiting-review]
diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py
index 2a33cf4d9..7bc090a0a 100644
--- a/python/mxnet/metric.py
+++ b/python/mxnet/metric.py
@@ -1576,9 +1576,8 @@ class PCC(EvalMetric):
n = max(pred.max(), label.max())
if n >= self.k:
self._grow(n + 1 - self.k)
- bcm = numpy.zeros((self.k, self.k))
- for i, j in zip(pred, label):
- bcm[i, j] += 1
+ k = self.k
+ bcm = numpy.bincount(label * k + pred, minlength=k*k).reshape((k, k))
self.lcm += bcm
self.gcm += bcm
This is a bit faster when k is small, but scales to larger k poorly.
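For comparing the candidate constructions, a small timing harness along these lines might help (the class count k and batch size n below are arbitrary; results will shift as they change):

import timeit
import numpy

k, n = 100, 10000  # arbitrary class count and batch size for the comparison
rng = numpy.random.default_rng(0)
pred = rng.integers(0, k, size=n)
label = rng.integers(0, k, size=n)

def loop():
    bcm = numpy.zeros((k, k))
    for i, j in zip(pred, label):
        bcm[i, j] += 1
    return bcm

def bincount():
    return numpy.bincount(label * k + pred, minlength=k * k).reshape((k, k))

def tensordot():
    ident = numpy.identity(k)
    return numpy.tensordot(ident[label], ident[pred].T, axes=(0, 1))

for fn in (loop, bincount, tensordot):
    print(fn.__name__, timeit.timeit(fn, number=10))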
Force-pushed from b514fee to 7fb48d2.
It turned out building the confusion matrix wasn't the hotspot in …
@szha Can you help with the review of this PR?
@tlby thanks for the contribution. Great job.
…e#14461)
* Adds a multiclass-MCC metric derived from Pearson
* trigger ci
Description
A multiclass metric equivalent to mxnet.metric.MCC can be derived from mxnet.metric.PearsonCorrelation with the addition of an .argmax() on preds. I'd like to document this use case of Pearson and provide it behind a metric named "PCC" to simplify extending examples from F1 and MCC to multiclass predictions.
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.
Changes
Comments
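As background for the Description above: the statistic commonly called multiclass MCC is Gorodkin's R_K, which can be computed directly from a confusion matrix. A minimal NumPy sketch of that textbook formula, for context rather than as this PR's exact code path (multiclass_mcc is a hypothetical name):

import numpy

def multiclass_mcc(cm):
    # cm is a k x k confusion matrix; rows index predictions, columns labels.
    c = numpy.trace(cm)   # correctly classified samples
    s = cm.sum()          # total samples
    p = cm.sum(axis=1)    # samples predicted as each class
    t = cm.sum(axis=0)    # true samples of each class
    cov_xy = c * s - p.dot(t)
    cov_xx = s * s - p.dot(p)
    cov_yy = s * s - t.dot(t)
    denom = numpy.sqrt(cov_xx * cov_yy)
    return cov_xy / denom if denom else 0.0

For two classes this reduces to the familiar binary MCC, which is why examples written against MCC extend to PCC with little more than a rename.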